{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(size=100000):\n",
    "    \"\"\"\n",
    "    生成数据:\n",
    "        输入数据X：在时间t，Xt的值有50%的概率为1，50%的概率为0；\n",
    "        输出数据Y：在实践t，Yt的值有50%的概率为1，50%的概率为0，除此之外，如果`Xt-3 == 1`，Yt为1的概率增加50%， 如果`Xt-8 == 1`，则Yt为1的概率减少25%， 如果上述两个条件同时满足，则Yt为1的概率为75%。\n",
    "    \"\"\"\n",
    "    X = np.random.choice(2,(size,))\n",
    "    Y = []\n",
    "    for i in range(size):\n",
    "        threshold = 0.5\n",
    "        # 判断X[i-3]和X[i-8]是否为1，修改阈值\n",
    "        if X[i-3] == 1:\n",
    "            threshold += 0.5\n",
    "        if X[i-8] == 1:\n",
    "            threshold -= 0.25\n",
    "        # 生成随机数，以threshold为阈值给Yi赋值\n",
    "        if np.random.rand() > threshold:\n",
    "            Y.append(0)\n",
    "        else:\n",
    "            Y.append(1)\n",
    "    return X,np.array(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_batch(raw_data, batch_size, num_steps):\n",
    "    # raw_data是使用gen_data()函数生成的数据，分别是X和Y\n",
    "    raw_x, raw_y = raw_data\n",
    "    data_length = len(raw_x)\n",
    "    \n",
    "    # 首先将数据切分成batch_size份，0-batch_size，batch_size-2*batch_size。。。\n",
    "    batch_partition_length = data_length // batch_size\n",
    "    data_x = np.zeros([batch_size, batch_partition_length], dtype=np.int32)\n",
    "    data_y = np.zeros([batch_size, batch_partition_length], dtype=np.int32)\n",
    "    \n",
    "    # 因为RNN模型一次只处理num_steps个数据，所以将每个batch_size在进行切分成epoch_size份，每份num_steps个数据。注意这里的epoch_size和模型训练过程中的epoch不同。\n",
    "    for i in range(batch_size):\n",
    "        data_x[i] = raw_x[i * batch_partition_length:(i + 1) * batch_partition_length]\n",
    "        data_y[i] = raw_x[i * batch_partition_length:(i + 1) * batch_partition_length]\n",
    "        \n",
    "    # x是0-num_steps， batch_partition_length -batch_partition_length +num_steps。。。共batch_size个\n",
    "    epoch_size = batch_partition_length // num_steps\n",
    "    for i in range(epoch_size):\n",
    "        x = data_x[:, i * num_steps:(i + 1) * num_steps]\n",
    "        y = data_x[:, i * num_steps:(i + 1) * num_steps]\n",
    "        yield (x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_epochs(n, num_steps):\n",
    "    \"\"\"Yield one batch generator per training epoch.\n",
    "\n",
    "    `n` is the number of training epochs. Each epoch draws a new dataset from\n",
    "    gen_data() and wraps it in gen_batch(); `batch_size` is read from the\n",
    "    notebook's global scope.\n",
    "    \"\"\"\n",
    "    epoch_index = 0\n",
    "    while epoch_index < n:\n",
    "        raw_data = gen_data()\n",
    "        yield gen_batch(raw_data, batch_size, num_steps=num_steps)\n",
    "        epoch_index += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0815 13:59:50.105675 139691886995264 deprecation.py:323] From <ipython-input-6-938f039f2ac0>:20: BasicRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.SimpleRNNCell, and will be replaced by that in Tensorflow 2.0.\n",
      "W0815 13:59:50.107444 139691886995264 deprecation.py:323] From <ipython-input-6-938f039f2ac0>:21: static_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell, unroll=True)`, which is equivalent to this API\n",
      "W0815 13:59:50.110475 139691886995264 deprecation.py:323] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:455: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n",
      "W0815 13:59:50.120152 139691886995264 deprecation.py:506] From /usr/local/python3/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:459: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\EPOCH 0\n",
      "Average loss at step 100 for last 100 steps: 0.45625598669052125\n",
      "Average loss at step 200 for last 100 steps: 0.33522686779499056\n",
      "Average loss at step 300 for last 100 steps: 0.32363319605588914\n",
      "Average loss at step 400 for last 100 steps: 0.3198443701863289\n",
      "Average loss at step 500 for last 100 steps: 0.31799371004104615\n",
      "Average loss at step 600 for last 100 steps: 0.316926654279232\n",
      "Average loss at step 700 for last 100 steps: 0.3162370014190674\n",
      "Average loss at step 800 for last 100 steps: 0.3157545381784439\n",
      "Average loss at step 900 for last 100 steps: 0.31540178656578066\n",
      "Average loss at step 1000 for last 100 steps: 0.3151378232240677\n",
      "Average loss at step 1100 for last 100 steps: 0.31492061764001844\n",
      "Average loss at step 1200 for last 100 steps: 0.314748175740242\n",
      "Average loss at step 1300 for last 100 steps: 0.3146084254980087\n",
      "Average loss at step 1400 for last 100 steps: 0.3144863533973694\n",
      "Average loss at step 1500 for last 100 steps: 0.3143901726603508\n",
      "Average loss at step 1600 for last 100 steps: 0.3143030619621277\n",
      "Average loss at step 1700 for last 100 steps: 0.31422749906778336\n",
      "Average loss at step 1800 for last 100 steps: 0.3141659554839134\n",
      "Average loss at step 1900 for last 100 steps: 0.31411062598228456\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAfcUlEQVR4nO3dfZAcd33n8fdnZnZn9DBjG2l3BZKMbCMgsp0Q3Z5tckAcHhQ5JLZ5SmRz4At3Ma5YFacODkyFcjmGVGGSOHdcdIACTrgAEZCDO9VZxCZHOBLONlobP0Q2imTZjqUYaSULS6unffreH9Oz21rvw0j73P15VW1N969/PfPd0ejTvT3961ZEYGZm2VWY7QLMzGx6OejNzDLOQW9mlnEOejOzjHPQm5llXGm2Cxhp6dKlsWrVqtkuw8xsXnnooYcORkTbaMvmXNCvWrWKrq6u2S7DzGxekfTsWMuaOnQjab2knZJ2S7p1nH7vkhSSOlNtPyvpfkk7JD0uqXJm5ZuZ2WRMuEcvqQhsAt4G7AW2S9oaEU+M6FcFbgEeTLWVgC8D74uIRyUtAfqmsH4zM5tAM3v0lwG7I2JPRPQCW4BrRun3CeBO4GSqbR3wWEQ8ChARhyJiYJI1m5nZGWgm6JcDz6Xm9yZtQyStBVZGxD0j1n01EJLulfSwpI+M9gKSbpTUJamru7v7DMo3M7OJTPr0SkkF4C7gQ6MsLgFvAN6bPL5D0ltGdoqIzRHRGRGdbW2jfmlsZmZnqZmg3wesTM2vSNoaqsAlwPckPQNcAWxNvpDdC3w/Ig5GxHFgG7B2Kgo3M7PmNBP024HVki6Q1ApsALY2FkbEixGxNCJWRcQq4AHg6ojoAu4FLpW0MPli9heBJ176EmZmNl0mDPqI6Ac2Ug/tJ4GvR8QOSXdIunqCdQ9TP6yzHXgEeHiU4/hTYt9PT3DXfTt55uCx6Xh6M7N5q6kBUxGxjfphl3TbbWP0vXLE/Jepn2I5rX56vJfPfHc3a15RY9XSRdP9cmZm80ZmrnXTUauPw9p/5NQsV2JmNrdkJuhftrCVUkHsP3Jy4s5mZjmSmaAvFERbtew9ejOzETIT9ADttQoHjnqP3swsLVNB31Etc8B79GZmp8lU0LfXyuz3Hr2Z2WkyFfQd1Qo/Pd7HqX5fN83MrCFbQZ+cYunDN2ZmwzIV9O21MoC/kDUzS8lW0Fc9aMrMbKRMBX1HY4/eg6bMzIZkKujPW9hKS1HsP+o9ejOzhkwFfaEg2qsVXwbBzCwlU0EP0FYt0+09ejOzIZkL+o5a2Xv0ZmYpGQz6is+6MTNLyVzQt1fLvHiij5N9Hh1rZgZZDPpkdKyP05uZ1WUu6IfvNOXj9GZm0GTQS1ovaaek3ZJuHaffuySFpM4R7edL6pH04ckWPJHGoCkfpzczq5sw6CUVgU3AVcAa4DpJa0bpVwVuAR4c5WnuAr49uVKb07gMgq93Y2ZW18we/WXA7ojYExG9wBbgmlH6fQK4EzgtYSVdCzwN7JhkrU05b2FLfXSs9+jNzIDmgn458Fxqfm/SNkTSWmBlRNwzon0x8FHg98d7AUk3SuqS1NXd3d1U4eM8F+3Viq93Y2aWmPSXsZIK1A/NfGiUxbcDfxIRPeM9R0RsjojOiOhsa2ubbEm+05SZWUqpiT77gJWp+RVJW0MVuAT4niSAZcBWSVcDlwPvlvRp4FxgUNLJiPjTqSh+LB3VCk91j7ttMTPLjWaCfjuwWtIF1AN+A3B9Y2FEvAgsbcxL+h7w4YjoAt6Yar8d6JnukIf6mTf/76mD0/0yZmbzwoSHbiKiH9gI3As8CXw9InZIuiPZa59z2msVjpzs50SvR8eamTWzR09EbAO2jWi7bYy+V47RfvsZ1nbW2qvDtxR85ZJFM/WyZmZzUuZGxkLqJuG+DIKZWbaD3pdBMDPLaNA3Dt140JSZWUaD/tyFLbQW
C74MgpkZGQ16SbTXyhzwHr2ZWTaDHhp3mvIevZlZZoO+vep7x5qZQYaDvqNW8emVZmZkOOjba2WOnuzneG//bJdiZjarMhv0HY0bkPgLWTPLucwGfXutcRkEB72Z5Vtmg96jY83M6rIb9FUHvZkZZDjoawtKtJYKPnRjZrmX2aCXREet7HvHmlnuZTbooX74xhc2M7O8y3bQ1yq+SbiZ5V6mg76t6gubmZllOug7ahV6TvVz7JRHx5pZfjUV9JLWS9opabekW8fp9y5JIakzmX+bpIckPZ48vnmqCm9GhwdNmZlNHPSSisAm4CpgDXCdpDWj9KsCtwAPppoPAr8WEZcCNwB/ORVFN6vd59KbmTW1R38ZsDsi9kREL7AFuGaUfp8A7gSGUjUifhQR/5LM7gAWSCpPsuameY/ezKy5oF8OPJea35u0DZG0FlgZEfeM8zzvAh6OiJekrqQbJXVJ6uru7m6ipOa01xoXNvMevZnl16S/jJVUAO4CPjROn4up7+1/cLTlEbE5IjojorOtrW2yJQ2pVUpUWgo+dGNmudZM0O8DVqbmVyRtDVXgEuB7kp4BrgC2pr6QXQF8C3h/RDw1FUU3SxLtHjRlZjnXTNBvB1ZLukBSK7AB2NpYGBEvRsTSiFgVEauAB4CrI6JL0rnAPcCtEfGDaah/Qh21Mgc8aMrMcmzCoI+IfmAjcC/wJPD1iNgh6Q5JV0+w+kbgVcBtkh5JftonXfUZaK9VPGjKzHKt1EyniNgGbBvRdtsYfa9MTX8S+OQk6pu09mqZ7/kYvZnlWKZHxkJ9dOyx3gF6PDrWzHIqB0GfnEvvvXozy6nsB/3Q6FgfpzezfMp80A/fJNx79GaWTzkI+sboWO/Rm1k+ZT7oq+USC1qKHh1rZrmV+aBv3Dt2vy9sZmY5lfmgh/rlin3WjZnlVT6Cvlb2pYrNLLdyEfQdtQr7j5wkIma7FDOzGZeLoG+vljnu0bFmllO5CPqOximWPnxjZjmUi6BvDJryKZZmlke5CPoOD5oysxzLRdC3V71Hb2b5lYugX1wusbC16GP0ZpZLuQj6+ujYivfozSyXchH0AG3Vso/Rm1kuNRX0ktZL2ilpt6Rbx+n3LkkhqTPV9rFkvZ2Sfnkqij4bHbWKL1VsZrk0YdBLKgKbgKuANcB1ktaM0q8K3AI8mGpbA2wALgbWA/8teb4Z11Ets//IKY+ONbPcaWaP/jJgd0TsiYheYAtwzSj9PgHcCaR3m68BtkTEqYh4GtidPN+M66hVONE3wFGPjjWznGkm6JcDz6Xm9yZtQyStBVZGxD1num6y/o2SuiR1dXd3N1X4mWr3vWPNLKcm/WWspAJwF/Chs32OiNgcEZ0R0dnW1jbZkkbVXvWgKTPLp1ITffYBK1PzK5K2hipwCfA9SQDLgK2Srm5i3RnT0bgMgr+QNbOcaWaPfjuwWtIFklqpf7m6tbEwIl6MiKURsSoiVgEPAFdHRFfSb4OksqQLgNXAD6f8t2hC496x+71Hb2Y5M+EefUT0S9oI3AsUgbsjYoekO4CuiNg6zro7JH0deALoB26OiIEpqv2MLC6XWNRa9KEbM8udZg7dEBHbgG0j2m4bo++VI+b/APiDs6xvSnXUKj50Y2a5k5uRsZDcUtBn3ZhZzuQr6KsVH6M3s9zJVdB31MocOOp7x5pZvuQs6Cuc7BvkyEmPjjWz/MhV0LcP3WnKx+nNLD/yFfTJnaZ8AxIzy5NcBX3H0KAp79GbWX7kKuiH7x3rPXozy49cBf2iconF5ZL36M0sV3IV9FAfNNXtY/RmliO5C/qOqm8Sbmb5kr+gr5V9vRszy5XcBX17reJ7x5pZruQv6KtlevsHOXLCo2PNLB9yF/RD59L78I2Z5UTugn74XHoHvZnlQ+6CvqPmm4SbWb7kLujbfZNwM8uZ3AX9wtYS1UrJe/RmlhtNBb2k9ZJ2Stot6dZRlt8k6XFJj0j6B0lrkvYWSV9Klj0p6WNT/Quc
jfZq2cfozSw3Jgx6SUVgE3AVsAa4rhHkKV+NiEsj4nXAp4G7kvb3AOWIuBT4V8AHJa2aotrPWket4ksVm1luNLNHfxmwOyL2REQvsAW4Jt0hIo6kZhcBjdFIASySVAIWAL1Auu+s6Kj5Mghmlh/NBP1y4LnU/N6k7TSSbpb0FPU9+t9Jmv8aOAY8D/wz8EcR8cIo694oqUtSV3d39xn+CmeuvVrmgEfHmllOTNmXsRGxKSIuAj4KfDxpvgwYAF4BXAB8SNKFo6y7OSI6I6Kzra1tqkoaU3utQu/AIC+e6Jv21zIzm23NBP0+YGVqfkXSNpYtwLXJ9PXA30REX0QcAH4AdJ5NoVOpo+YbkJhZfjQT9NuB1ZIukNQKbAC2pjtIWp2afTuwK5n+Z+DNSZ9FwBXAjydb9GT5loJmlieliTpERL+kjcC9QBG4OyJ2SLoD6IqIrcBGSW8F+oDDwA3J6puAP5e0AxDw5xHx2HT8ImfCl0EwszyZMOgBImIbsG1E222p6VvGWK+H+imWc0p7NbkMgk+xNLMcyN3IWIAFrUVqlRIHvEdvZjmQy6CHxrn03qM3s+zLbdC318oc8IXNzCwHchv09ZuEe4/ezLIvt0HfXqtw4OhJj441s8zLb9BXy/QNBIePe3SsmWVbboN+6E5TPk5vZhmX46D3ZRDMLB9yHPS+DIKZ5UNug74tuQyCB02ZWdblNugrLUXOWdDiyyCYWeblNuihfpzeh27MLOtyHfTtHjRlZjmQ76Cvlen2oRszy7hcB31HMjp2cNCjY80su/Id9EOjY3tnuxQzs2mT66BvHzqX3odvzCy7ch30jdGxvgyCmWVZroN+6JaC3qM3swxrKuglrZe0U9JuSbeOsvwmSY9LekTSP0hak1r2s5Lul7Qj6VOZyl9gMtprvkm4mWXfhEEvqQhsAq4C1gDXpYM88dWIuDQiXgd8GrgrWbcEfBm4KSIuBq4E5sx1gculIucu9OhYM8u2ZvboLwN2R8SeiOgFtgDXpDtExJHU7CKgcb7iOuCxiHg06XcoIgYmX/bUqd9pynv0ZpZdzQT9cuC51PzepO00km6W9BT1PfrfSZpfDYSkeyU9LOkjo72ApBsldUnq6u7uPrPfYJLaa2X2e4/ezDJsyr6MjYhNEXER8FHg40lzCXgD8N7k8R2S3jLKupsjojMiOtva2qaqpKa0Vyu+gqWZZVozQb8PWJmaX5G0jWULcG0yvRf4fkQcjIjjwDZg7dkUOl06kssgeHSsmWVVM0G/HVgt6QJJrcAGYGu6g6TVqdm3A7uS6XuBSyUtTL6Y/UXgicmXPXU6ahX6B4MXPDrWzDKqNFGHiOiXtJF6aBeBuyNih6Q7gK6I2ApslPRW6mfUHAZuSNY9LOku6huLALZFxD3T9LuclY7UKZZLF5dnuRozs6k3YdADRMQ26odd0m23paZvGWfdL1M/xXJOamsMmjp6iotnuRYzs+mQ65GxkLoMgr+QNbOMyn3QN+4d6wubmVlW5T7oy6Ui5y1s8aApM8us3Ac9NG5A4j16M8smBz3169L7GL2ZZZWDnvqdpnyM3syyykFPcpPwnlMMeHSsmWWQg576MfqBweCFYx4da2bZ46Bn+E5TPvPGzLLIQc/wnaZ871gzyyIHPfVDN+B7x5pZNjnogbbFHh1rZtnloAdaSwWWLGplvw/dmFkGOegTbdWyB02ZWSY56BO+DIKZZZWDPtFRK/v0SjPLJAd9oqNWofuoR8eaWfY46BPt1TKDAYeO+fCNmWWLgz7R7nPpzSyjmgp6Sesl7ZS0W9Ktoyy/SdLjkh6R9A+S1oxYfr6kHkkfnqrCp1pj0JSP05tZ1kwY9JKKwCbgKmANcN3IIAe+GhGXRsTrgE8Dd41Yfhfw7Smod9q0+5aCZpZRzezRXwbsjog9EdELbAGuSXeIiCOp2UXA0Deakq4FngZ2TL7c6dO4d6yvd2NmWdNM0C8HnkvN703aTiPpZklPUd+j/52k
bTHwUeD3x3sBSTdK6pLU1d3d3WztU6qlWGDp4lbv0ZtZ5kzZl7ERsSkiLqIe7B9Pmm8H/iQieiZYd3NEdEZEZ1tb21SVdMbaq76loJllT6mJPvuAlan5FUnbWLYAn02mLwfeLenTwLnAoKSTEfGnZ1PsdGuvlT061swyp5mg3w6slnQB9YDfAFyf7iBpdUTsSmbfDuwCiIg3pvrcDvTM1ZAH6KhWeOJfjkzc0cxsHpkw6COiX9JG4F6gCNwdETsk3QF0RcRWYKOktwJ9wGHghukserp01Moc7DlF/8AgpaKHGJhZNjSzR09EbAO2jWi7LTV9SxPPcfuZFjfT2mqVZHRs79B59WZm8513W1NWnrcAgK888CwRvuaNmWWDgz7ljavbeOfa5Xzmu7v51N/82GFvZpnQ1KGbvCgWxB+9++dY0FLk8/93Dyd6B7j91y6mUNBsl2ZmdtYc9CMUCuKT117ConKJzd/fw/HeAT71zkv95ayZzVsO+lFI4mNXvZaFrUX+89/u4kTvAH/yG6+jteSwN7P5x0E/Bkn87ltfzaLWEn+w7UlO9g2w6b1rqbQUZ7s0M7Mz4l3UCfzWmy7kk9dewnd3HuADf7GdY6f6Z7skM7Mz4qBvwr+94pX88Xt+jgf2HOL9d/+QF0/0zXZJZmZNc9A36Z1rV7Dp+rU8tvenvPcLD/DCsd7ZLsnMrCkO+jNw1aUvZ/P7Otm1v4ff+Pz9vtKlmc0LDvoz9EuvbefPf/Nfs++nJ3jP5+9n7+Hjs12Smdm4HPRn4RcuWsqX/8PlvHCsl1//3P08ffDYbJdkZjYmB/1ZWnv+efzVb13Byf5B3vO5+9n5k6OzXZKZ2agc9JNwyfJz+PoHr6Ag2LD5fh7f++Jsl2Rm9hIO+kl6VXuVb9z0eha2lrj+zx6g65kXZrskM7PTOOinwCuXLOIbN72etmqZ933xh/xg98HZLsnMbIiDfoq84twFfO2Dr+eVSxbyvi8+yK9/7n6+8Pd7ePaQv6g1s9mluXbN9c7Ozujq6prtMs7ai8f7uPsHT3PfE/t58vn6/Wdf01Fl3cUdrFuzjEuW15B82WMzm1qSHoqIzlGXOeinz3MvHOe+J/Zz346fsP2ZFxgMePk5Fdat6eBta5Zx+YUvo8WXPzazKTDpoJe0Hvgv1G8O/oWI+NSI5TcBNwMDQA9wY0Q8IeltwKeAVqAX+E8R8d3xXitLQZ/2wrFevvvjA9y34yd8f1c3J/sGqVVKvPm17ay7eBlvenUbi8u+mKiZnZ1JBb2kIvBPwNuAvcB24LqIeCLVpxYRR5Lpq4Hfjoj1kn4e2B8R/yLpEuDeiFg+3utlNejTTvQO8Pe7uvnOE/v52yf3c/h4H63FAv/mVUtYd/Ey3vIz7bRXfXNyM2veeEHfzC7kZcDuiNiTPNkW4BpgKOgbIZ9YBETS/qNU+w5ggaRyRJw6s18hWxa0Fll38TLWXbyM/oFBHnr2cP0QzxM/4e+++TgSXPKKc1jdvpgL2xZxYVv9cdWSRb4evpmdsWaCfjnwXGp+L3D5yE6Sbgb+I/XDNG8e5XneBTw8WshLuhG4EeD8889voqTsKBULXH7hEi6/cAkff/vPsHP/Ue7bsZ8Hnz7E/XsO8c0f7RvqK8HycxfUg3/pIi5KbQSW1Sr+ktfMRjVlB4UjYhOwSdL1wMeBGxrLJF0M3AmsG2PdzcBmqB+6maqa5htJvHZZjdcuqwGrATh2qp+nDx5jz8Fj7OnuYU/3MfYc7KHrmRc43jswtO6CliIXLF009BfARW2LOP9lC1m6uExbtey/BMxyrJmg3wesTM2vSNrGsgX4bGNG0grgW8D7I+KpsykyzxaVS1yy/BwuWX7Oae0Rwf4jp9jT3cNTqY3Ao3t/yj2PP8/Ir14WtRZZWi2zZFErSxeXWbK4TNviVpYsLifz9fali1s5Z0GL/zowy5Bmgn47sFrSBdQDfgNwfbqDpNURsSuZfTuwK2k/F7gHuDUifjBl
VRuSWHZOhWXnVPiFVy09bdnJvgGePXScvYePc6inl+6eUxzq6eVgzykOHTvFs4eO8/A/H+bQsd6XbBAASgUNBf95C1upVkosLpeoVlpYXClRGzFfrZSoJvPVSomFrUVvKMzmkAmDPiL6JW0E7qV+euXdEbFD0h1AV0RsBTZKeivQBxxm+LDNRuBVwG2Sbkva1kXEgan+RWxYpaXIa5ZVec2y6rj9BgaDw8eTDUCyITjY2CAk04eP97L/yEl6TvVz9GQ/PU3cM7cghjYEjeBf0FpkQUuRBa0lFrQUWNBSpNJoa6kvryTTC5P29PJyS4FyqUhrqUC5VKBUkDcmZk3ygCk7I4ODQU9vPz0nG8Hfx5GTp88fTabrP32c6BvgRO/A6Y/J9Kn+wbOqQ4JyqUBrsUC5pZg8JvOl0zcKjceWYoGWUoGWgmgpFigVC7QWRamYLCsqeSxQKorW5DG9rFSot5UKOn26WEgek/b0dNLHGyabTpM9vdJsSKEgapUWapWWKXm+wcHgZP8Ax3vrwX+yb+AlG4bjvQP0Dgxyqq/xOFh/7B+kt3+QU/0DqenhtuO9/Rw+PtzeNzBI30DQNzBIfzLdO3B2G5qzUSzUA7+Y/Jw+XaBQgFKhUG9T0l4UBZ3ed+hHolAQBdWfu6DT24ceC1BUfUPTWLfe9/T2ghheb5TnbqwjDdcnMbSsoPqyQqOt0Jgfbhurv8boU59+6fMAQ/UVJMRLn0sIFYaXN55PGq0t2xthB73NqkJBLGwtsbB1dj6KEcHAYNQ3AIOD9PUP0j8Y9CaP9Y1DegMR9A/Wl/UPBP0DyfRgY1nyMzA49Lwj+wwkfQaj/jgwEAxEqj3pOzA43NaYbtQ1GMNtETAQ9fUazzM8zVDfwVSfwcFknYhRv6fJI4nTw596w9BGI7WBQOl+w9Mw3KYR6w29Rqq9sYFqTP/Sa9r5+K+umfLfzUFvuSYlh1iKsIB8noIaEQwGwxuDoQ0DwxuGZOOQ3mgMDAaRWj/dJxrzybLT+qSWn/Ycg/WRlvXl6f715ZGab/RpbOQiSD1PJM9z+nrDzz1cT6SfC5LnGV4W1J94MPUa6Y3j4Ijnr7c3XqO+fno9UnUOv95wnS8/d8G0/Bs76M1yrn4opn6IxrLJl040M8s4B72ZWcY56M3MMs5Bb2aWcQ56M7OMc9CbmWWcg97MLOMc9GZmGTfnLmomqRt4dhJPsRQ4OEXlTCfXObXmS50wf2p1nVNruut8ZUS0jbZgzgX9ZEnqGusKbnOJ65xa86VOmD+1us6pNZt1+tCNmVnGOejNzDIui0G/ebYLaJLrnFrzpU6YP7W6zqk1a3Vm7hi9mZmdLot79GZmluKgNzPLuHkZ9JLWS9opabekW0dZXpb0tWT5g5JWzXyVIGmlpL+T9ISkHZJuGaXPlZJelPRI8nPbLNX6jKTHkxpecnd21X0meU8fk7R2Fmp8Tep9ekTSEUm/O6LPrL2fku6WdEDSP6baXibpO5J2JY/njbHuDUmfXZJumIU6/1DSj5N/229JOneMdcf9nMxAnbdL2pf69/2VMdYdNyNmoM6vpWp8RtIjY6w7M+9nDN2Sa378AEXgKeBCoBV4FFgzos9vA59LpjcAX5ulWl8OrE2mq8A/jVLrlcD/ngPv6zPA0nGW/wrwbeq3trwCeHAOfA5+Qn2QyJx4P4E3AWuBf0y1fRq4NZm+FbhzlPVeBuxJHs9Lps+b4TrXAaVk+s7R6mzmczIDdd4OfLiJz8a4GTHddY5Y/sfAbbP5fs7HPfrLgN0RsScieoEtwDUj+lwDfCmZ/mvgLZqF27xHxPMR8XAyfRR4Elg+03VMkWuA/x51DwDnSnr5LNbzFuCpiJjMKOopFRHfB14Y0Zz+LH4JuHaUVX8Z+E5EvBARh4HvAOtnss6IuC8i+pPZB4AV0/X6zRrj/WxGMxkxZcarM8mdXwf+arpevxnzMeiXA8+l5vfy0vAc6pN8eF8ElsxIdWNIDh/9PPDg
KItfL+lRSd+WdPGMFjYsgPskPSTpxlGWN/O+z6QNjP2fZy68nw0dEfF8Mv0ToGOUPnPtvf0A9b/eRjPR52QmbEwOMd09xqGwufR+vhHYHxG7xlg+I+/nfAz6eUfSYuB/AL8bEUdGLH6Y+uGHnwP+K/A/Z7q+xBsiYi1wFXCzpDfNUh0TktQKXA18Y5TFc+X9fImo/60+p89nlvR7QD/wlTG6zPbn5LPARcDrgOepHxaZy65j/L35GXk/52PQ7wNWpuZXJG2j9pFUAs4BDs1IdSNIaqEe8l+JiG+OXB4RRyKiJ5neBrRIWjrDZRIR+5LHA8C3qP/5m9bM+z5TrgIejoj9IxfMlfczZX/jEFfyeGCUPnPivZX074BfBd6bbJReoonPybSKiP0RMRARg8CfjfH6c+X9LAHvBL42Vp+Zej/nY9BvB1ZLuiDZs9sAbB3RZyvQOHPh3cB3x/rgTqfk+NwXgScj4q4x+ixrfH8g6TLq/yYzulGStEhStTFN/Yu5fxzRbSvw/uTsmyuAF1OHJGbamHtJc+H9HCH9WbwB+F+j9LkXWCfpvORQxLqkbcZIWg98BLg6Io6P0aeZz8m0GvG90DvGeP1mMmImvBX4cUTsHW3hjL6f0/1t73T8UD8D5J+of7P+e0nbHdQ/pAAV6n/W7wZ+CFw4S3W+gfqf6o8BjyQ/vwLcBNyU9NkI7KB+ZsADwC/MQp0XJq//aFJL4z1N1ylgU/KePw50ztJ7uoh6cJ+TapsT7yf1jc/zQB/148L/nvp3Q/8H2AX8LfCypG8n8IXUuh9IPq+7gd+chTp3Uz+u3ficNs5aewWwbbzPyQzX+ZfJ5+8x6uH98pF1JvMvyYiZrDNp/4vG5zLVd1beT18Cwcws4+bjoRszMzsDDnozs4xz0JuZZZyD3sws4xz0ZmYZ56A3M8s4B72ZWcb9f8W/GShI3/AYAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "batch_size = 5\n",
    "num_steps = 10\n",
    "state_size = 10\n",
    "n_classes = 2\n",
    "learning_rate = 0.1\n",
    "\n",
    "x = tf.placeholder(tf.int32, [batch_size, num_steps])\n",
    "y = tf.placeholder(tf.int32, [batch_size, num_steps])\n",
    "\n",
    "# Initial RNN state, all zeros. It carries a batch dimension: one hidden-state row per sample.\n",
    "init_state = tf.zeros([batch_size, state_size])\n",
    "\n",
    "# One-hot encode the inputs over the two classes -> [batch_size, num_steps, n_classes].\n",
    "x_one_hot = tf.one_hot(x, n_classes)\n",
    "\n",
    "# Unstack along the time axis so every unrolled step gets its own batch of inputs.\n",
    "rnn_inputs = tf.unstack(x_one_hot, axis=1)\n",
    "\n",
    "# The cell owns the recurrent weights; static_rnn unrolls it over the per-step inputs.\n",
    "cell = tf.contrib.rnn.BasicRNNCell(state_size)\n",
    "rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=init_state)\n",
    "\n",
    "# Parameters of the softmax output layer.\n",
    "with tf.variable_scope('softmax'):\n",
    "    W = tf.get_variable('W', [state_size, n_classes])\n",
    "    b = tf.get_variable('b', [n_classes])\n",
    "\n",
    "# Project each per-step output to class logits; `predictions` holds the per-step probabilities.\n",
    "logits = [tf.matmul(rnn_output, W) + b for rnn_output in rnn_outputs]\n",
    "predictions = [tf.nn.softmax(logit) for logit in logits]\n",
    "\n",
    "# tf.stack joins tensors along an axis and tf.unstack splits them; split the labels per time step.\n",
    "y_as_lists = tf.unstack(y, num=num_steps, axis=1)\n",
    "\n",
    "# The loss op applies softmax internally, so it must be fed the raw logits, not `predictions`.\n",
    "losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logit) for label, logit in zip(y_as_lists, logits)]\n",
    "total_loss = tf.reduce_mean(losses)\n",
    "train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)\n",
    "\n",
    "def train_network(num_epochs, num_steps, state_size, verbose=True):\n",
    "    \"\"\"Train the RNN graph built in this cell and return the averaged losses.\n",
    "\n",
    "    Uses the notebook-level graph objects (x, y, init_state, total_loss,\n",
    "    final_state, train_step) and the global `batch_size`.\n",
    "\n",
    "    Args:\n",
    "        num_epochs: number of epochs to run; gen_epochs() supplies one batch\n",
    "            generator per epoch.\n",
    "        num_steps: number of time steps per batch, forwarded to gen_epochs().\n",
    "        state_size: width of the RNN state, used to build the zero start state.\n",
    "        verbose: when True, print the epoch header and the running averages.\n",
    "\n",
    "    Returns:\n",
    "        List of loss averages, one entry recorded every 100 steps, across all epochs.\n",
    "    \"\"\"\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        training_losses = []\n",
    "\n",
    "        for idx, epoch in enumerate(gen_epochs(num_epochs, num_steps)):\n",
    "            training_loss = 0\n",
    "            # Start every epoch from a zero state; within the epoch the final state\n",
    "            # of one batch is fed in as the initial state of the next batch.\n",
    "            training_state = np.zeros((batch_size, state_size))\n",
    "            if verbose:\n",
    "                print('\\nEPOCH', idx)\n",
    "\n",
    "            # Pull the batches for this epoch and run one optimisation step on each.\n",
    "            for step, (X, Y) in enumerate(epoch):\n",
    "                training_loss_, training_state, _ = sess.run(\n",
    "                    [total_loss, final_state, train_step],\n",
    "                    feed_dict={x: X, y: Y, init_state: training_state})\n",
    "\n",
    "                training_loss += training_loss_\n",
    "                if step % 100 == 0 and step > 0:\n",
    "                    if verbose:\n",
    "                        print(\"Average loss at step\", step, \"for last 100 steps:\", training_loss / 100)\n",
    "\n",
    "                    training_losses.append(training_loss / 100)\n",
    "                    training_loss = 0\n",
    "\n",
    "        # Return only after every epoch has run, not from inside the epoch loop.\n",
    "        return training_losses\n",
    "\n",
    "# Train for a single epoch, then plot the recorded running-average losses.\n",
    "training_losses = train_network(num_epochs=1, num_steps=num_steps, state_size=state_size)\n",
    "plt.plot(training_losses)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
